{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "VIDEO_PATH = './data/scene-2.mov'\n",
    "SAVE_DIR = \"./data/processed/scene-2/\"\n",
    "VIDEO_OUT = './data/processed/scene_2.mp4'\n",
    "module_handle = \"https://tfhub.dev/google/openimages_v4/ssd/mobilenet_v2/1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "import plotly.express as px\n",
    "import os\n",
    "from tqdm.notebook import tqdm\n",
    "import tensorflow_hub as hub\n",
    "import tensorflow as tf\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For drawing onto the image.\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "from PIL import ImageColor\n",
    "from PIL import ImageDraw\n",
    "from PIL import ImageFont\n",
    "from PIL import ImageOps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_bounding_box_on_image(image,\n",
    "                               ymin,\n",
    "                               xmin,\n",
    "                               ymax,\n",
    "                               xmax,\n",
    "                               color,\n",
    "                               font,\n",
    "                               thickness=4,\n",
    "                               display_str_list=()):\n",
    "  \"\"\"Adds a bounding box to an image.\"\"\"\n",
    "  draw = ImageDraw.Draw(image)\n",
    "  im_width, im_height = image.size\n",
    "  (left, right, top, bottom) = (xmin * im_width, xmax * im_width,\n",
    "                                ymin * im_height, ymax * im_height)\n",
    "  draw.line([(left, top), (left, bottom), (right, bottom), (right, top),\n",
    "             (left, top)],\n",
    "            width=thickness,\n",
    "            fill=color)\n",
    "\n",
    "  # If the total height of the display strings added to the top of the bounding\n",
    "  # box exceeds the top of the image, stack the strings below the bounding box\n",
    "  # instead of above.\n",
    "  display_str_heights = [font.getsize(ds)[1] for ds in display_str_list]\n",
    "  # Each display_str has a top and bottom margin of 0.05x.\n",
    "  total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights)\n",
    "\n",
    "  if top > total_display_str_height:\n",
    "    text_bottom = top\n",
    "  else:\n",
    "    text_bottom = top + total_display_str_height\n",
    "  # Reverse list and print from bottom to top.\n",
    "  for display_str in display_str_list[::-1]:\n",
    "    text_width, text_height = font.getsize(display_str)\n",
    "    margin = np.ceil(0.05 * text_height)\n",
    "    draw.rectangle([(left, text_bottom - text_height - 2 * margin),\n",
    "                    (left + text_width, text_bottom)],\n",
    "                   fill=color)\n",
    "    draw.text((left + margin, text_bottom - text_height - margin),\n",
    "              display_str,\n",
    "              fill=\"black\",\n",
    "              font=font)\n",
    "    text_bottom -= text_height - 2 * margin\n",
    "\n",
    "\n",
    "def draw_boxes(image, boxes, class_ids, class_names, scores, font, max_boxes=10, min_score=0.1):\n",
    "  \"\"\"Overlay labeled boxes on an image with formatted scores and label names.\"\"\"\n",
    "  colors = list(ImageColor.colormap.values())\n",
    "\n",
    "  for i in range(min(boxes.shape[0], max_boxes)):\n",
    "    if scores[i] >= min_score:\n",
    "      ymin, xmin, ymax, xmax = tuple(boxes[i])\n",
    "      display_str = \"{}: {}%\".format(class_names[i].decode(\"ascii\"),\n",
    "                                     int(100 * scores[i]))\n",
    "      color = colors[class_ids[i] % len(colors)]\n",
    "      image_pil = Image.fromarray(np.uint8(image)).convert(\"RGB\")\n",
    "      draw_bounding_box_on_image(\n",
    "          image_pil,\n",
    "          ymin,\n",
    "          xmin,\n",
    "          ymax,\n",
    "          xmax,\n",
    "          color,\n",
    "          font,\n",
    "          display_str_list=[display_str])\n",
    "      np.copyto(image, np.array(image_pil))\n",
    "  return np.array(image_pil)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fast_draw_boxes(image, boxes, class_ids, class_names, scores, font, max_boxes=10, min_score=0.1):\n",
    "  \"\"\"Overlay labeled boxes on an image with formatted scores and label names.\"\"\"\n",
    "  colors = list(ImageColor.colormap.values())\n",
    "  image_pil = Image.fromarray(np.uint8(image)).convert(\"RGB\")\n",
    "\n",
    "  for i in range(min(boxes.shape[0], max_boxes)):\n",
    "    if scores[i] >= min_score:\n",
    "      ymin, xmin, ymax, xmax = tuple(boxes[i])\n",
    "      display_str = \"{}: {}%\".format(class_names[i].decode(\"ascii\"),\n",
    "                                     int(100 * scores[i]))\n",
    "      color = colors[class_ids[i] % len(colors)]\n",
    "      draw_bounding_box_on_image(\n",
    "          image_pil,\n",
    "          ymin,\n",
    "          xmin,\n",
    "          ymax,\n",
    "          xmax,\n",
    "          color,\n",
    "          font,\n",
    "          display_str_list=[display_str])\n",
    "#       np.copyto(image, np.array(image_pil))\n",
    "  return np.array(image_pil)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 8.07 s, sys: 659 ms, total: 8.73 s\n",
      "Wall time: 3.91 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "frames = []\n",
    "\n",
    "cap = cv2.VideoCapture(VIDEO_PATH)\n",
    "ret = True\n",
    "\n",
    "while ret:\n",
    "    ret, frame = cap.read()\n",
    "    if ret:\n",
    "        frame = cv2.rotate(frame, cv2.ROTATE_90_COUNTERCLOCKWISE)\n",
    "        frames.append(frame)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Saver not created because there are no variables in the graph to restore\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Saver not created because there are no variables in the graph to restore\n"
     ]
    }
   ],
   "source": [
    "detector = hub.load(module_handle).signatures['default']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    font = ImageFont.truetype(\"/usr/share/fonts/truetype/liberation/LiberationSansNarrow-Regular.ttf\",\n",
    "                          25)\n",
    "except IOError:\n",
    "    print(\"Font not found, using default font.\")\n",
    "    font = ImageFont.load_default()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'frames' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-8-46b4e23018e1>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0msample_rate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mimg\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtqdm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mframes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      5\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0msample_rate\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m         \u001b[0mresized\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcv2\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m512\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m512\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'frames' is not defined"
     ]
    }
   ],
   "source": [
    "processed_frames = []\n",
    "sample_rate = 3\n",
    "\n",
    "for i, img in enumerate(tqdm(frames)):\n",
    "    if i % sample_rate == 0:\n",
    "        resized = cv2.resize(img, (512, 512))\n",
    "        image_tensor = tf.image.convert_image_dtype(img, tf.float32)[tf.newaxis, ...]\n",
    "\n",
    "        result = detector(image_tensor)\n",
    "\n",
    "    image_with_boxes = fast_draw_boxes(\n",
    "        img.copy(), \n",
    "        result[\"detection_boxes\"].numpy(),\n",
    "        result['detection_class_labels'].numpy(),\n",
    "        result[\"detection_class_entities\"].numpy(), \n",
    "        result[\"detection_scores\"].numpy(),\n",
    "        font=font\n",
    "    )\n",
    "    \n",
    "    processed_frames.append(image_with_boxes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 122,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a3186f6d579f4cad8038f39046924e10",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=1203.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "CPU times: user 17 s, sys: 365 ms, total: 17.4 s\n",
      "Wall time: 17.4 s\n"
     ]
    }
   ],
   "source": [
    "os.makedirs(SAVE_DIR, exist_ok=True)\n",
    "for i, frame in enumerate(tqdm(processed_frames)):\n",
    "    cv2.imwrite(os.path.join(SAVE_DIR, f\"frame-{i}.jpg\"), frame, [cv2.IMWRITE_JPEG_QUALITY, 90])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 139,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3788889a0b3a46279a91fbf6d57fa102",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=1203.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Define the codec and create VideoWriter object\n",
    "fourcc = cv2.VideoWriter_fourcc(*'MP4V')\n",
    "out = cv2.VideoWriter(VIDEO_OUT,fourcc, 30, (1280, 720))\n",
    "\n",
    "for frame in tqdm(processed_frames):\n",
    "    out.write(frame)\n",
    "\n",
    "out.release()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
